package com.lw.leetcode.tree.c;

import com.lw.leetcode.tree.TreeNode;

import java.util.*;

/**
 * Created with IntelliJ IDEA.
 * <p>
 * 1932. Merge BSTs to Create Single BST
 *
 * @author liw
 * @version 1.0
 * @date 2022/7/13 15:16
 */
public class CanMerge {


    public static void main(String[] args) {
        CanMerge test = new CanMerge();

        // [[1,null,3],[3,1],[4,2]]
        // [[2,1],[3,2,5],[5,4]]
        TreeNode a = TreeNode.getInstance16();
        TreeNode b = TreeNode.getInstance17();
        TreeNode c = TreeNode.getInstance18();
        List<TreeNode> list = Arrays.asList(a, b, c);

        TreeNode treeNode = test.canMerge(list);
        System.out.println(treeNode);

    }


    /**
     * Input trees that have not been attached yet, keyed by their root value.
     * Cleared at the start of every {@link #canMerge} call.
     */
    private final Map<Integer, TreeNode> map = new HashMap<>();

    /**
     * Merges the given trees into a single valid BST, if possible.
     * <p>
     * A merge replaces a leaf with the tree whose root has the same value. The merge succeeds only
     * if exactly one root value never appears as a leaf of another tree, every other tree gets
     * attached, and the combined tree satisfies strict BST ordering.
     * <p>
     * Note: nodes of the input trees are re-linked while merging, even when {@code null} is
     * returned.
     *
     * @param trees input trees; root values are assumed to be unique (a duplicate root value would
     *              overwrite the earlier entry in the lookup map)
     * @return root of the merged BST, or {@code null} if no valid merge exists
     */
    public TreeNode canMerge(List<TreeNode> trees) {
        // Drop any state left over from a previous call on this instance.
        map.clear();
        Set<Integer> roots = new HashSet<>();
        Set<Integer> leaves = new HashSet<>();
        for (TreeNode tree : trees) {
            map.put(tree.val, tree);
            roots.add(tree.val);
            if (tree.left != null) {
                // Early rejection: a tree already seen whose root equals this left leaf would be
                // attached here, so its right child has to stay below this root.
                TreeNode sub = map.get(tree.left.val);
                if (sub != null && sub.right != null && sub.right.val >= tree.val) {
                    return null;
                }
                leaves.add(tree.left.val);
            }
            if (tree.right != null) {
                // Symmetric early rejection: the attached tree's left child must stay above this root.
                TreeNode sub = map.get(tree.right.val);
                if (sub != null && sub.left != null && sub.left.val <= tree.val) {
                    return null;
                }
                leaves.add(tree.right.val);
            }
        }
        // The final root is the only root value that is not a leaf of some other tree.
        roots.removeAll(leaves);
        if (roots.size() != 1) {
            return null;
        }
        TreeNode root = map.get(roots.iterator().next());
        // find() removes every tree it attaches. Anything still in the map (e.g. a cycle of trees
        // detached from the root) was never merged, so the result would be incomplete.
        if (find(root, Long.MIN_VALUE, Long.MAX_VALUE) && map.isEmpty()) {
            return root;
        }
        return null;
    }

    /**
     * Walks the tree from {@code node} and grafts in the matching input trees.
     * <p>
     * Whenever a visited node's value is the root value of an unused input tree, that tree's
     * children are linked onto the node, and the tree is removed from {@link #map}. Every attached
     * child is checked against the strict bounds inherited from its ancestors.
     * <p>
     * The walk uses an explicit stack instead of recursion, so that a long chain of trees cannot
     * overflow the call stack.
     *
     * @param node  starting node
     * @param low   exclusive lower bound for values in this subtree
     * @param high  exclusive upper bound for values in this subtree
     * @return {@code false} on the first ordering violation, {@code true} otherwise
     */
    private boolean find(TreeNode node, long low, long high) {
        Deque<Frame> stack = new ArrayDeque<>();
        stack.push(new Frame(node, low, high));
        while (!stack.isEmpty()) {
            Frame frame = stack.pop();
            TreeNode current = frame.node;
            int val = current.val;
            // Each input tree may be attached at most once, so take it out of the map as we use it.
            TreeNode subtree = map.remove(val);
            if (subtree == null) {
                continue;
            }
            if (subtree.left != null) {
                int v = subtree.left.val;
                current.left = subtree.left;
                if (v <= frame.low || v >= val) {
                    return false;
                }
                stack.push(new Frame(current.left, frame.low, val));
            }
            if (subtree.right != null) {
                int v = subtree.right.val;
                current.right = subtree.right;
                if (v >= frame.high || v <= val) {
                    return false;
                }
                stack.push(new Frame(current.right, val, frame.high));
            }
        }
        return true;
    }

    /** A pending node together with the exclusive (low, high) bounds its value must satisfy. */
    private static final class Frame {
        final TreeNode node;
        final long low;
        final long high;

        Frame(TreeNode node, long low, long high) {
            this.node = node;
            this.low = low;
            this.high = high;
        }
    }

}
